package com.sillyhat.project.core.interceptor;

import com.sillyhat.project.common.dto.DataTables;
import com.sillyhat.project.common.utils.ReflectHelper;

import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.MappedStatement.Builder;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.domain.Page;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Properties;

@Intercepts({
	@Signature(
		type=StatementHandler.class,
		method="prepare",
		args={
			Connection.class
		}
	)
})
public class SillyHatMybatisInterceptor implements Interceptor {

	private final Logger logger = LoggerFactory.getLogger(SillyHatMybatisInterceptor.class);

	private String dialect = ""; //数据库方言
	private String pageSqlId = ""; //mapper.xml中需要拦截的ID(正则匹配)
	/**
	 * 对于StatementHandler其实只有两个实现类，一个是RoutingStatementHandler，另一个是抽象类BaseStatementHandler，
	 * BaseStatementHandler有三个子类，分别是SimpleStatementHandler，PreparedStatementHandler和CallableStatementHandler，
	 * SimpleStatementHandler是用于处理Statement的，PreparedStatementHandler是处理PreparedStatement的，而CallableStatementHandler是
	 * 处理CallableStatement的。Mybatis在进行Sql语句处理的时候都是建立的RoutingStatementHandler，而在RoutingStatementHandler里面拥有一个
	 * StatementHandler类型的delegate属性，RoutingStatementHandler会依据Statement的不同建立对应的BaseStatementHandler，即SimpleStatementHandler、
	 * PreparedStatementHandler或CallableStatementHandler，在RoutingStatementHandler里面所有StatementHandler接口方法的实现都是调用的delegate对应的方法。
	 * 我们在PageInterceptor类上已经用@Signature标记了该Interceptor只拦截StatementHandler接口的prepare方法，又因为Mybatis只有在建立RoutingStatementHandler的时候
	 * 是通过Interceptor的plugin方法进行包裹的，所以我们这里拦截到的目标对象肯定是RoutingStatementHandler对象。
	 * @param invocation
	 * @return
	 * @throws Throwable
	 */
	public Object intercept(Invocation invocation) throws Throwable {
		if(invocation.getTarget() instanceof RoutingStatementHandler){
			RoutingStatementHandler statementHandler = (RoutingStatementHandler)invocation.getTarget();
			StatementHandler delegate = (StatementHandler) ReflectHelper.getFieldValue(statementHandler, "delegate");
			BoundSql boundSql = delegate.getBoundSql();
			Object obj = boundSql.getParameterObject();
			if (obj instanceof DataTables) {
				DataTables dataTables = (DataTables) obj;
				//通过反射获取delegate父类BaseStatementHandler的mappedStatement属性
				MappedStatement mappedStatement = (MappedStatement)ReflectHelper.getFieldValue(delegate, "mappedStatement");
				//拦截到的prepare方法参数是一个Connection对象
				Connection connection = (Connection)invocation.getArgs()[0];
				//获取当前要执行的Sql语句，也就是我们直接在Mapper映射语句中写的Sql语句
				String sql = boundSql.getSql();
				//给当前的page参数对象设置总记录数
				this.setTotalRecord(dataTables, mappedStatement, connection);
				//获取分页Sql语句
				String pageSql = this.getPageSql(dataTables, sql);
				//利用反射设置当前BoundSql对应的sql属性为我们建立好的分页Sql语句
				ReflectHelper.setFieldValue(boundSql, "sql", pageSql);
			}
		}
		// 当前环境 MappedStatement，BoundSql，及sql取得
//		MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
//		Object parameter = invocation.getArgs()[1];
//		BoundSql boundSql = mappedStatement.getBoundSql(parameter);
//		String originalSql = boundSql.getSql().trim();
//		logger.info("SQL:-----------------begin---------------------");
//		logger.info("\n"+originalSql);
//		logger.info("SQL:-----------------end---------------------");
//		Object parameterObject = boundSql.getParameterObject();
		// Page对象获取，“信使”到达拦截器！
//		Paging page = searchPageWithXpath(boundSql.getParameterObject(), ".","page", "*/page");
//		if (page != null) {
//			// Page对象存在的场合，开始分页处理
//			String countSql = getCountSql(originalSql);
//			originalSql = changeSql(page, originalSql);
//			String orderDirection = page.getOrderDirection();
//			String orderField = page.getOrderField();
//			if(orderField != null && !"".equals(orderField)){
//				if(originalSql.indexOf("order") != -1){
//					originalSql = originalSql.substring(0, originalSql.indexOf("order")) + " order by " + orderField + " " + orderDirection;
//				}else{
//					originalSql += " order by " + orderField + " " + orderDirection;
//				}
//			}
//			Connection connection = mappedStatement.getConfiguration().getEnvironment().getDataSource().getConnection();
//			PreparedStatement countStmt = connection.prepareStatement(countSql);
//			BoundSql countBS = copyFromBoundSql(mappedStatement, boundSql,countSql);
//			DefaultParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, countBS);
//			parameterHandler.setParameters(countStmt);
//			ResultSet rs = countStmt.executeQuery();
//			int totpage = 0;
//			if (rs.next()) {
//				totpage = rs.getInt(1);
//			}
//			rs.close();
//			countStmt.close();
//			connection.close();
//			// 分页计算
//			int pageNo = page.getPageNo();//当前页码
//			int pageSize = page.getPageSize();//每页行数
//			page.setTotalRecord(totpage);
//			page.setTotalPage(getLastPage(totpage, pageSize));
//			StringBuffer sql = null;
//			if(ReadDatabaseProperties.getValueByKey(Constants.DATABASE_0_CONNECTION_SORT).equals(Constants.DATABASE_SORT_ORACLE)){
//				sql = getOracleSql(originalSql, pageNo,pageSize,totpage);
//			}else if(ReadDatabaseProperties.getValueByKey(Constants.DATABASE_0_CONNECTION_SORT).equals(Constants.DATABASE_SORT_MYSQL)){
//				sql = getMysqlSql(originalSql, pageNo,pageSize,totpage);
//			}
//			logger.debug("SQL:-----------------format begin---------------------");
//			logger.debug("\n"+sql);
//			logger.debug("SQL:-----------------format end---------------------");
//			BoundSql newBoundSql = copyFromBoundSql(mappedStatement, boundSql,sql.toString());
//			MappedStatement newMs = copyFromMappedStatement(mappedStatement,new BoundSqlSqlSource(newBoundSql));
//			invocation.getArgs()[0] = newMs;
//		}
		return invocation.proceed();
	}


	/**
	 * 根据page对象获取对应的分页查询Sql语句，这里只做了两种数据库类型，Mysql和Oracle
	 * 其它的数据库都 没有进行分页
	 *
	 * @param page 分页对象
	 * @param sql 原sql语句
	 * @return
	 */
	private String getPageSql(DataTables dataTables, String sql) {
		StringBuffer sqlBuffer = new StringBuffer(sql);
		if ("mysql".equalsIgnoreCase(dialect)) {
			return getMysqlPageSql(dataTables, sqlBuffer);
		} else if ("oracle".equalsIgnoreCase(dialect)) {
//			return getOraclePageSql(dataTables, sqlBuffer);
		}
		return sqlBuffer.toString();
	}

	/**
	 * 获取Mysql数据库的分页查询语句
	 * @param page 分页对象
	 * @param sqlBuffer 包含原sql语句的StringBuffer对象
	 * @return Mysql数据库分页语句
	 */
	private String getMysqlPageSql(DataTables dataTables, StringBuffer sqlBuffer) {
		//计算第一条记录的位置，Mysql中记录的位置是从0开始的
		sqlBuffer.append("\n\r limit ").append(dataTables.getStart()).append(",").append(dataTables.getLength());
		return sqlBuffer.toString();
	}

	/**
	 * 获取Oracle数据库的分页查询语句
	 * @param page 分页对象
	 * @param sqlBuffer 包含原sql语句的StringBuffer对象
	 * @return Oracle数据库的分页查询语句
	 */
	private String getOraclePageSql(Page<?> page, StringBuffer sqlBuffer) {
		//计算第一条记录的位置，Oracle分页是通过rownum进行的，而rownum是从1开始的
//		int offset = (page.getPage() - 1) * page.getRows() + 1;
//		sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getRows());
//		sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset);
		//上面的Sql语句拼接之后大概是这个样子：
		//select * from (select u.*, rownum r from (select * from t_user) u where rownum < 31) where r >= 16
		return sqlBuffer.toString();
	}

	/**
	 * 给当前的参数对象page设置总记录数
	 *
	 * @param page Mapper映射语句对应的参数对象
	 * @param mappedStatement Mapper映射语句
	 * @param connection 当前的数据库连接
	 */
	private void setTotalRecord(DataTables dataTables, MappedStatement mappedStatement, Connection connection) {
		//获取对应的BoundSql，这个BoundSql其实跟我们利用StatementHandler获取到的BoundSql是同一个对象。
		//delegate里面的boundSql也是通过mappedStatement.getBoundSql(paramObj)方法获取到的。
		BoundSql boundSql = mappedStatement.getBoundSql(dataTables);
		//获取到我们自己写在Mapper映射语句中对应的Sql语句
		String sql = boundSql.getSql();
		//通过查询Sql语句获取到对应的计算总记录数的sql语句
		String countSql = this.getCountSql(sql);
		//通过BoundSql获取对应的参数映射
		List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
		//利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。
		BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, dataTables);
		//通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象
		ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, dataTables, countBoundSql);
		//通过connection建立一个countSql对应的PreparedStatement对象。
		PreparedStatement pstmt = null;
		ResultSet rs = null;
		try {
			pstmt = connection.prepareStatement(countSql);
			//通过parameterHandler给PreparedStatement对象设置参数
			parameterHandler.setParameters(pstmt);
			//之后就是执行获取总记录数的Sql语句和获取结果了。
			rs = pstmt.executeQuery();
			if (rs.next()) {
				int totalRecord = rs.getInt(1);
				//给当前的参数page对象设置总记录数
				dataTables.setRecordsTotal(totalRecord);
				dataTables.setRecordsFiltered(totalRecord);
			}
		} catch (SQLException e) {
			logger.error("封装查询总记录数发生异常",e);
		} finally {
			try {
				if (rs != null)
					rs.close();
				if (pstmt != null)
					pstmt.close();
			} catch (SQLException e) {
				logger.error("关闭数据库连接发生异常");
			}
		}
	}

	/**
	 * 根据原Sql语句获取对应的查询总记录数的Sql语句
	 * @param sql
	 * @return
	 */
	private String getCountSql(String sql) {
//		int index = sql.indexOf("from");
//		if(index == -1){
//			index = sql.indexOf("FROM");
//		}
		return "select count(1) from (\n\r" + sql + "\n\r) t";
	}

	/**
	 * <p>Title: getOracleSql</p>
	 * <p>Description: </p>TODO
	 * @param @param originalSql	原始SQL
	 * @param @param pageNo	当前页码
	 * @param @param pageSize	每页行数
	 * @param @param totalRecord	总记录数
	 * @author 徐士宽
	 * @date 2015-4-10
	 * @return:StringBuffer
	 */
	private StringBuffer getOracleSql(String originalSql,int pageNo,int pageSize,int totalRecord){
		int startRecord = pageNo * pageSize - pageSize + 1;
		int endRecord = pageNo * pageSize;
		StringBuffer sql = new StringBuffer();
		sql.append("select * from (select t.*, rownum rn from (").append(originalSql);
		//判断排序
		sql.append("");
		sql.append(") t where rownum <= ").append(endRecord).append(") where rn >= ").append(startRecord);
		return sql;
	}


	public String getDialect() {
		return dialect;
	}

	public void setDialect(String dialect) {
		this.dialect = dialect;
	}

	public String getPageSqlId() {
		return pageSqlId;
	}

	public void setPageSqlId(String pageSqlId) {
		this.pageSqlId = pageSqlId;
	}

	/**
	 * 拦截器对应的封装原始对象的方法
	 */
	public Object plugin(Object arg0) {
		// TODO Auto-generated method stub
		if (arg0 instanceof StatementHandler) {
			return Plugin.wrap(arg0, this);
		} else {
			return arg0;
		}
	}

	/**
	 * 设置注册拦截器时设定的属性
	 */
	public void setProperties(Properties properties) {
		setDialect(properties.getProperty("dialect"));
	}
}
